{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MF7BncmmLBeO"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as tt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized; it serves an educational purpose. It is written for CPU, and it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all the components needed to understand how Generative Adversarial Networks (GANs) work, and it should be rather easy to extend it to more sophisticated models. This code can be run on almost any laptop/PC, and it takes a couple of minutes at most to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKsmjLumL5A2"
   },
   "source": [
    "## Dataset: Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scikit-learn dataset called Digits. It consists of ~1800 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without a GPU or any other special hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSWUnXAYLLif"
   },
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSP2qiMqMICK"
   },
   "source": [
    "## GANs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "znlDY599MRLy"
   },
   "source": [
    "### Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "x5Q8Iz41LTAj"
   },
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    def __init__(self, generator_net, z_size):\n",
    "        super(Generator, self).__init__()\n",
    "        \n",
    "        # We need to init the generator neural net.\n",
    "        self.generator_net = generator_net\n",
    "        # We also need to know the size of the latents.\n",
    "        self.z_size = z_size\n",
    "\n",
    "    def generate(self, z):\n",
    "        # Generating for given z is equivalent to applying the neural net.\n",
    "        return self.generator_net(z)\n",
    "\n",
    "    def sample(self, batch_size=16):\n",
    "        # For sampling, we need to sample first latents.\n",
    "        z = torch.randn(batch_size, self.z_size)\n",
    "        return self.generate(z)\n",
    "\n",
    "    def forward(self, z=None):\n",
    "        if z is None:\n",
    "            return self.sample()\n",
    "        else:\n",
    "            return self.generate(z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m_fYlZJiMSih"
   },
   "source": [
    "### Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xJYhNXulLUhz"
   },
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, discriminator_net):\n",
    "        super(Discriminator, self).__init__()\n",
    "        # We need to init the discriminator neural net.\n",
    "        self.discriminator_net = discriminator_net\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The forward pass is just about applying the neural net.\n",
    "        return self.discriminator_net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sCoRJeYpMToZ"
   },
   "source": [
    "### GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GRYA6JA4LWEC"
   },
   "outputs": [],
   "source": [
    "class GAN(nn.Module):\n",
    "    def __init__(self, generator, discriminator, EPS=1.e-5):\n",
    "        super(GAN, self).__init__()\n",
    "\n",
    "        print('GAN by JT.')\n",
    "        \n",
    "        # To put everything together, we need the generator and \n",
    "        # the discriminator. NOTE: Both are intanced of classes!\n",
    "        self.generator = generator\n",
    "        self.discriminator = discriminator\n",
    "        \n",
    "        # For numerical issue, we introduce a small epsilon.\n",
    "        self.EPS = EPS\n",
    "\n",
    "    def forward(self, x_real, reduction='avg', mode='discriminator'):\n",
    "        # The forward pass calculates the adversarial loss.\n",
    "        # More specifically, either its part for the generator or\n",
    "        #  the part for the discriminator.\n",
    "        if mode == 'generator':\n",
    "            # For the generator, we first sample FAKE data.\n",
    "            x_fake_gen = self.generator.sample(x_real.shape[0])\n",
    "\n",
    "            # Then, we calculate outputs of the discriminator for the FAKE data.\n",
    "            # NOTE: We clamp here for the numerical stability later on.\n",
    "            d_fake = torch.clamp(self.discriminator(x_fake_gen), self.EPS, 1. - self.EPS)\n",
    "            \n",
    "            # The loss for the generator is log(1 - D(G(z))).\n",
    "            loss = torch.log(1. - d_fake)\n",
    "\n",
    "        elif mode == 'discriminator':\n",
    "            # For the discriminator, we first sample FAKE data.\n",
    "            x_fake_gen = self.generator.sample(x_real.shape[0])\n",
    "\n",
    "            # Then, we calculate outputs of the discriminator for the FAKE data.\n",
    "            # NOTE: We clamp for the numerical stability later on.\n",
    "            d_fake = torch.clamp(self.discriminator(x_fake_gen), self.EPS, 1. - self.EPS)\n",
    "            \n",
    "            # Moreover, we calculate outputs of the discriminator for the REAL data.\n",
    "            # NOTE: We clamp for... the numerical stability (again).\n",
    "            d_real = torch.clamp(self.discriminator(x_real), self.EPS, 1. - self.EPS)\n",
    "\n",
    "            # The final loss for the discriminator is log(1 - D(G(z))) + log D(x).\n",
    "            # NOTE: We take the minus sign because we MAXIMIZE the adversarial loss wrt \n",
    "            # discriminator, so we MINIMIZE the negative adversarial loss wrt discriminator.\n",
    "            loss = -(torch.log(d_real) + torch.log(1. - d_fake))\n",
    "        \n",
    "        if reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        else:\n",
    "            return loss.mean()\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        return self.generator.sample(batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vUoPkTmrMVnx"
   },
   "source": [
    "## Evaluation and Training functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JvwmRoi7MVto"
   },
   "source": [
    "**Evaluation step, sampling and curve plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JHx4RIqDLZe9"
   },
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss_gen = 0.\n",
    "    loss_dis = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t_gen = model_best.forward(test_batch, reduction='sum', mode='generator')\n",
    "        loss_t_dis = model_best.forward(test_batch, reduction='sum', mode='discriminator')\n",
    "        \n",
    "        loss_gen = loss_gen + loss_t_gen.item()\n",
    "        loss_dis = loss_dis + loss_t_dis.item()\n",
    "        \n",
    "        N = N + test_batch.shape[0]\n",
    "    \n",
    "    loss_gen = loss_gen / N\n",
    "    loss_dis = loss_dis / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: gen={loss_gen}, dis={loss_dis}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val gen={loss_gen}, val dis={loss_dis}')\n",
    "\n",
    "    return loss_gen, loss_dis\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    \"\"\"Save a 4x4 grid of REAL images taken from the first batch of `test_loader`.\n",
    "\n",
    "    Every image is reshaped to 8x8 for plotting, and the figure is written to\n",
    "    `name + '_real_images.pdf'`. If the first batch holds fewer than 16 images,\n",
    "    the remaining panels are left empty.\n",
    "    \"\"\"\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = next(iter(test_loader)).detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        # A small first batch must not raise an IndexError: skip missing images.\n",
    "        if i < len(x):\n",
    "            plottable_image = np.reshape(x[i], (8, 8))\n",
    "            ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name=''):\n",
    "    \"\"\"Save a 4x4 grid of images sampled from the model stored at `name + '.model'`.\n",
    "\n",
    "    Args:\n",
    "        name: path prefix of the saved model and of the output figure.\n",
    "        data_loader: unused; kept so that existing calls keep working.\n",
    "        extra_name: suffix put between '_generated_images' and '.pdf' in the file name.\n",
    "    \"\"\"\n",
    "    # GENERATIONS-------\n",
    "    # NOTE: weights_only=False unpickles arbitrary objects, so only load model\n",
    "    # files you created yourself. PyTorch >= 2.6 defaults to weights_only=True,\n",
    "    # which rejects models saved as whole objects with torch.save(model, ...).\n",
    "    try:\n",
    "        model_best = torch.load(name + '.model', weights_only=False)\n",
    "    except TypeError:\n",
    "        # Fallback for PyTorch versions that predate the weights_only argument.\n",
    "        model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    # Sampling for plotting needs no gradients.\n",
    "    with torch.no_grad():\n",
    "        x = model_best.sample(num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def plot_curve(name, nll_val, xaxis='epochs', yaxis='nll'):\n",
    "    \"\"\"Plot the values of `nll_val` against their index and save the figure.\n",
    "\n",
    "    The curve is drawn on a fresh figure with `xaxis`/`yaxis` as axis labels\n",
    "    and written to `name + '_nll_val_curve.pdf'`.\n",
    "    \"\"\"\n",
    "    # A dedicated figure keeps the curve separate from any figure left open.\n",
    "    fig, ax = plt.subplots()\n",
    "    # linewidth is a number in matplotlib, so we pass 3 rather than the string '3'.\n",
    "    ax.plot(np.arange(len(nll_val)), nll_val, linewidth=3)\n",
    "    ax.set_xlabel(xaxis)\n",
    "    ax.set_ylabel(yaxis)\n",
    "    fig.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umU3VYKzMbDt"
   },
   "source": [
    "**Training step**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxkUZ1xVLbm_"
   },
   "outputs": [],
   "source": [
    "def training(name, num_epochs, model, optimizer_gen, optimizer_dis, training_loader, val_loader):\n",
    "    \"\"\"Train the GAN by alternating discriminator and generator updates.\n",
    "\n",
    "    After every epoch, the model is evaluated on `val_loader`, saved to\n",
    "    `name + '.model'`, and a grid of generated images is written to disk.\n",
    "\n",
    "    Args:\n",
    "        name: path prefix for the saved model and the generated-image files.\n",
    "        num_epochs: number of passes over `training_loader`.\n",
    "        model: the GAN; called with a batch and `mode='discriminator'` or `mode='generator'`.\n",
    "        optimizer_gen: optimizer stepped after the generator loss.\n",
    "        optimizer_dis: optimizer stepped after the discriminator loss.\n",
    "        training_loader: iterable of training batches.\n",
    "        val_loader: iterable of validation batches.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (gen_val, dis_val) of numpy arrays with the validation losses per epoch.\n",
    "    \"\"\"\n",
    "    gen_val = []\n",
    "    dis_val = []\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for batch in training_loader:\n",
    "\n",
    "            # -Discriminator\n",
    "            loss_dis = model(batch, mode='discriminator')\n",
    "\n",
    "            optimizer_dis.zero_grad()\n",
    "            optimizer_gen.zero_grad()\n",
    "            # The graph is not reused (the generator loss below comes from a fresh\n",
    "            # forward pass), so there is no need to retain it after backward.\n",
    "            loss_dis.backward()\n",
    "            optimizer_dis.step()\n",
    "\n",
    "            # -Generator\n",
    "            loss_gen = model(batch, mode='generator')\n",
    "\n",
    "            # Zeroing both optimizers discards the gradients left over from the discriminator step.\n",
    "            optimizer_dis.zero_grad()\n",
    "            optimizer_gen.zero_grad()\n",
    "            loss_gen.backward()\n",
    "            optimizer_gen.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val_gen, loss_val_dis = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        gen_val.append(loss_val_gen)  # save for plotting\n",
    "        dis_val.append(loss_val_dis)  # save for plotting\n",
    "\n",
    "        # The model is saved first, because samples_generated loads it from disk.\n",
    "        torch.save(model, name + '.model')\n",
    "        samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "\n",
    "    gen_val = np.asarray(gen_val)\n",
    "    dis_val = np.asarray(dis_val)\n",
    "\n",
    "    return gen_val, dis_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0BXJ9dN0MinB"
   },
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KsF7f-Q-MkWu"
   },
   "source": [
    "**Initialize datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: rescale the pixel values via 2 * (x / 17) - 1, wrap the\n",
    "# numpy array in a torch tensor, then add Gaussian noise scaled by 0.03.\n",
    "transforms_train = tt.Compose([\n",
    "    tt.Lambda(lambda pixels: 2*(pixels / 17.)-1.),\n",
    "    tt.Lambda(torch.from_numpy),\n",
    "    tt.Lambda(lambda tensor: tensor + 0.03 * torch.randn_like(tensor)),\n",
    "])\n",
    "\n",
    "# Validation/test pipeline: the same rescaling and conversion, without the noise.\n",
    "transforms_val = tt.Compose([\n",
    "    tt.Lambda(lambda pixels: 2*(pixels / 17.)-1.),\n",
    "    tt.Lambda(torch.from_numpy),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fqZKMNM0LdQ1"
   },
   "outputs": [],
   "source": [
    "# Training split: noisy transforms, reshuffled every epoch.\n",
    "train_data = Digits(mode='train', transforms=transforms_train)\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "\n",
    "# Validation split: noise-free transforms, fixed order.\n",
    "val_data = Digits(mode='val', transforms=transforms_val)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "\n",
    "# Test split: noise-free transforms, fixed order.\n",
    "test_data = Digits(mode='test', transforms=transforms_val)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6lEKUznpMns7"
   },
   "source": [
    "**Hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANQo7LrGLjIN"
   },
   "outputs": [],
   "source": [
    "D = 64   # input dimension (the plotting code reshapes each sample to 8x8)\n",
    "L = 16 # number of latents (size of z fed to the generator)\n",
    "M = 128  # the number of hidden units in the generator and discriminator nets\n",
    "\n",
    "lr_gen = 3e-4 # learning rate of the generator's optimizer\n",
    "lr_dis = 3e-4 # learning rate of the discriminator's optimizer\n",
    "num_epochs = 500 # number of training epochs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7APXeunMrDh"
   },
   "source": [
    "**Creating a folder for results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bjSUn1eWLkWm"
   },
   "outputs": [],
   "source": [
    "name = 'gan_' + str(L)\n",
    "result_dir = 'results/' + name + '/'\n",
    "# makedirs also creates the missing parent folder 'results', and exist_ok=True\n",
    "# makes re-running this cell safe without separate existence checks.\n",
    "os.makedirs(result_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hpwm6LWUMulQ"
   },
   "source": [
    "**Initializing the model: (i) defining the generator and discriminator nets, (ii) combining them into the GAN**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FrnNsCqQLmK3",
    "outputId": "5f0cf2b1-0a96-4f5c-da9e-f78f909a5259"
   },
   "outputs": [],
   "source": [
    "# First, we initialize the generator and the discriminator\n",
    "# -generator: latents (L) -> hidden units (M) -> data space (D), squashed by Tanh\n",
    "generator_net = nn.Sequential(\n",
    "    nn.Linear(L, M),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(M, D),\n",
    "    nn.Tanh(),\n",
    ")\n",
    "generator = Generator(generator_net=generator_net, z_size=L)\n",
    "\n",
    "# -discriminator: data (D) -> hidden units (M) -> a single output squashed by Sigmoid\n",
    "discriminator_net = nn.Sequential(\n",
    "    nn.Linear(D, M),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(M, 1),\n",
    "    nn.Sigmoid(),\n",
    ")\n",
    "discriminator = Discriminator(discriminator_net=discriminator_net)\n",
    "\n",
    "# Eventually, we put both parts together in the full model\n",
    "model = GAN(generator, discriminator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3SzTemY3NSxO"
   },
   "source": [
    "**Optimizer - here we use Adam**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R9TZtLVtLoWc"
   },
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "# One Adam optimizer per network, so that each step only updates\n",
    "# the parameters of the corresponding part of the model.\n",
    "optimizer_gen = torch.optim.Adam(params=model.generator.parameters(), lr=lr_gen)\n",
    "optimizer_dis = torch.optim.Adam(params=model.discriminator.parameters(), lr=lr_dis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dNf__W_ONVHA"
   },
   "source": [
    "**Training loop**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhqHgluGLqIC",
    "outputId": "c52fa1e4-3376-4bff-9f87-6f03613c4e42"
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "# Arguments follow the order of the signature of `training`; the first one is\n",
    "# the path prefix used for the saved model and the generated-image files.\n",
    "gen_val, dis_val = training(\n",
    "    result_dir + name,\n",
    "    num_epochs,\n",
    "    model,\n",
    "    optimizer_gen,\n",
    "    optimizer_dis,\n",
    "    training_loader,\n",
    "    val_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-3XTxgEcNXfp"
   },
   "source": [
    "**The final evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "okK1mV_-LrRU",
    "outputId": "4664693f-742d-4453-94cf-d051d2efa9be"
   },
   "outputs": [],
   "source": [
    "# Evaluate the model saved at `result_dir + name + '.model'` on the test set.\n",
    "test_gen_loss, test_dis_loss = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "# The context manager closes the file even if one of the writes fails.\n",
    "with open(result_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write(str(test_gen_loss))\n",
    "    f.write('\\n')\n",
    "    f.write(str(test_dis_loss))\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "\n",
    "# NOTE: The parentheses in the axis labels are balanced to match the losses.\n",
    "plot_curve(result_dir + name + '_gen', gen_val, yaxis='$log(1-D(G(z)))$')\n",
    "plot_curve(result_dir + name + '_dis', dis_val, yaxis='$-(log(1-D(G(z))) + log(D(x)))$')\n",
    "\n",
    "samples_generated(result_dir + name, test_loader, extra_name='FINAL')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae_priors.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
